package org.bjf.aop;

import com.alibaba.fastjson2.JSON;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.bjf.exception.CommMsgCode;
import org.bjf.exception.ServiceException;
import org.bjf.modules.core.web.core.LoginInfo;
import org.bjf.modules.core.web.core.ThreadContext;
import org.bjf.modules.user.enums.UserRedisKey;
import org.bjf.utils.ExceptionAssert;
import org.bjf.utils.RedisUtil;
import org.bjf.utils.TokenUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import javax.servlet.http.HttpServletRequest;
import java.util.Date;
import java.util.Map;

/**
 * WebSocket handshake interceptor.
 *
 * <p>Authenticates the handshake with an access token taken from the {@code x-websocket-token}
 * request header (or, for servlet requests, the query parameter of the same name). It then
 * loads the logged-in user from Redis, refreshes that Redis entry's expiry, and exposes the
 * user through {@link ThreadContext} and the WebSocket session attributes.
 *
 * @author bjf
 */
@Slf4j
@Component
public class WebsocketInterceptor implements HandshakeInterceptor {

    /**
     * Session attribute key under which the authenticated {@link LoginInfo} is stored, so that
     * WebSocket handlers can read it from the session attributes instead of a thread-local.
     */
    public static final String LOGIN_INFO_ATTRIBUTE = "loginInfo";

    /** Name of both the request header and the query parameter that carry the access token. */
    private static final String TOKEN_NAME = "x-websocket-token";

    /**
     * Expiry written back to the Redis login entry on every handshake. The value is
     * 86400 * 15; presumably seconds, i.e. 15 days — confirm against {@code RedisUtil.setObj}.
     */
    private static final int LOGIN_TTL = 86400 * 15;

    @Autowired
    private RedisUtil redis;

    /**
     * Validates the access token and attaches the logged-in user to the handshake.
     *
     * @param request    the handshake request
     * @param response   the handshake response (unused)
     * @param wsHandler  the target WebSocket handler (unused)
     * @param attributes WebSocket session attributes; receives the {@link LoginInfo} under
     *                   {@link #LOGIN_INFO_ATTRIBUTE}
     * @return always {@code true} when no exception is thrown
     * @throws ServiceException with {@code CommMsgCode.UNAUTHORIZED} when the token is missing,
     *                          fails verification, or has no login entry in Redis
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
                                   WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        String accessToken = getAccessToken(request);
        if (StringUtils.isBlank(accessToken) || !TokenUtil.verifyToken(accessToken)) {
            // The token is a bearer credential: never write it to the logs in full.
            log.warn("invalid websocket token:{}", maskToken(accessToken));
            // NOTE(review): throwing from a HandshakeInterceptor presumably surfaces as a generic
            // handshake failure rather than a 401; consider setting response status and
            // returning false instead — confirm against the client contract.
            throw new ServiceException(CommMsgCode.UNAUTHORIZED);
        }
        LoginInfo loginInfo = getLoginInfo(accessToken);
        loginInfo.setLastTime(new Date());
        // Keep the thread-local for existing callers that read ThreadContext.
        ThreadContext.setLoginInfo(loginInfo);
        // Also bind the user to the WebSocket session itself, which outlives this thread.
        attributes.put(LOGIN_INFO_ATTRIBUTE, loginInfo);

        if (log.isInfoEnabled()) {
            // NOTE(review): this dumps the whole LoginInfo; verify it holds no secrets.
            log.info("websocket loginInfo：{}", JSON.toJSONString(loginInfo));
        }

        return true;
    }

    /**
     * No post-handshake work is done.
     *
     * <p>NOTE(review): the login info set on {@link ThreadContext} in {@link #beforeHandshake}
     * is not cleared here, so it stays on the current (possibly pooled) thread. It is not
     * cleared because code running later on the same thread may still rely on it — confirm
     * whether it can be cleared safely.
     */
    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
                               WebSocketHandler wsHandler, Exception exception) {

    }

    /**
     * Loads the login entry for {@code accessToken} from Redis and refreshes its expiry.
     *
     * @param accessToken a token that has already passed {@code TokenUtil.verifyToken}
     * @return the stored login info, never {@code null}
     * @throws ServiceException with {@code CommMsgCode.UNAUTHORIZED} if no entry exists
     */
    private LoginInfo getLoginInfo(String accessToken) {
        String userKey = UserRedisKey.TOKEN_API.as(accessToken);
        LoginInfo loginInfo = redis.getObj(userKey);
        ExceptionAssert.notNull(loginInfo, CommMsgCode.UNAUTHORIZED);
        // Re-store the entry to extend its Redis expiry (sliding session).
        redis.setObj(userKey, loginInfo, LOGIN_TTL);

        return loginInfo;
    }

    /**
     * Extracts the access token from the handshake request.
     *
     * <p>The {@value #TOKEN_NAME} header is checked first, for any request type. If it is
     * blank and the request is servlet-based, the query parameter of the same name is used
     * instead.
     *
     * @param request the handshake request
     * @return the token, or {@code null}/blank if none was supplied
     */
    private String getAccessToken(ServerHttpRequest request) {
        String accessToken = request.getHeaders().getFirst(TOKEN_NAME);
        if (StringUtils.isBlank(accessToken) && request instanceof ServletServerHttpRequest) {
            HttpServletRequest servletRequest = ((ServletServerHttpRequest) request).getServletRequest();
            accessToken = servletRequest.getParameter(TOKEN_NAME);
        }

        return accessToken;
    }

    /**
     * Produces a log-safe form of a token: blank input is returned unchanged, short tokens are
     * fully hidden, and longer ones keep only a 4-character prefix.
     */
    private static String maskToken(String token) {
        if (StringUtils.isBlank(token)) {
            return token;
        }
        if (token.length() <= 8) {
            return "***";
        }
        return token.substring(0, 4) + "***";
    }

}
